# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from unittest.mock import MagicMock, patch

import numpy as np
import pytest

from haystack import Document, Pipeline
from haystack.components.embedders import SentenceTransformersTextEmbedder
from haystack.components.embedders.sentence_transformers_document_embedder import SentenceTransformersDocumentEmbedder
from haystack.components.generators.chat import OpenAIChatGenerator
from haystack.components.query import QueryExpander
from haystack.components.retrievers import InMemoryEmbeddingRetriever, MultiQueryEmbeddingRetriever
from haystack.components.writers import DocumentWriter
from haystack.document_stores.in_memory import InMemoryDocumentStore
from haystack.document_stores.types import DuplicatePolicy


class TestMultiQueryEmbeddingRetriever:
    """Unit and integration tests for `MultiQueryEmbeddingRetriever`."""

    @pytest.fixture
    def sample_documents(self):
        """Seven short energy-related documents, each tagged with a `category` meta field."""
        return [
            Document(
                content="Renewable energy is energy that is collected from renewable resources.",
                meta={"category": None},
            ),
            Document(
                content="Solar energy is a type of green energy that is harnessed from the sun.",
                meta={"category": "solar"},
            ),
            Document(
                content="Wind energy is another type of green energy that is generated by wind turbines",
                meta={"category": "wind"},
            ),
            Document(
                content="Hydropower is a form of renewable energy using the flow of water to generate electricity.",
                meta={"category": "hydro"},
            ),
            Document(
                content="Geothermal energy is heat that comes from the sub-surface of the earth.",
                meta={"category": "geo"},
            ),
            Document(
                content="Fossil fuels, such as coal, oil, and natural gas, are non-renewable energy sources.",
                meta={"category": "fossil"},
            ),
            Document(
                content="Nuclear energy is produced through nuclear reactions, typically using uranium or "
                "plutonium as fuel.",
                meta={"category": "nuclear"},
            ),
        ]

    @pytest.fixture
    def document_store_with_embeddings(self, sample_documents):
        """Create an in-memory document store populated with the embedded sample documents."""
        document_store = InMemoryDocumentStore()
        doc_embedder = SentenceTransformersDocumentEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        doc_embedder.warm_up()
        doc_writer = DocumentWriter(document_store=document_store, policy=DuplicatePolicy.SKIP)

        embedded_docs = doc_embedder.run(sample_documents)["documents"]
        doc_writer.run(documents=embedded_docs)
        return document_store

    @pytest.fixture
    def mock_query_embedder(self):
        """
        Provide a `SentenceTransformersTextEmbedder` that never loads a real model.

        The embedding backend factory is patched out, and `run` is replaced by a stub that always
        returns a 384-dimensional embedding of ones, regardless of the input text.
        """
        with patch(
            "haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
        ) as mock_backend_factory:
            mock_model = MagicMock()
            mock_backend_factory.return_value = mock_model

            def mock_encode(
                texts, batch_size=None, show_progress_bar=None, normalize_embeddings=None, precision=None, **kwargs
            ):
                return [np.ones(384).tolist() for _ in texts]

            mock_model.encode = mock_encode
            embedder = SentenceTransformersTextEmbedder(model="mock-model", progress_bar=False)

            def mock_run(text):
                return {"embedding": np.ones(384).tolist()}

            # `run` is a plain function, not a MagicMock: tests cannot reconfigure it via `.return_value`.
            embedder.run = mock_run
            embedder.warm_up()
            # Yield instead of return so the patch stays active for the whole test, not just fixture setup.
            yield embedder

    def test_init_with_default_parameters(self, mock_query_embedder):
        """The wrapped components are stored as given and `max_workers` defaults to 3."""
        embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
        query_embedder = mock_query_embedder

        retriever = MultiQueryEmbeddingRetriever(retriever=embedding_retriever, query_embedder=query_embedder)

        assert retriever.retriever == embedding_retriever
        assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 3

    def test_init_with_custom_parameters(self, mock_query_embedder):
        """An explicit `max_workers` overrides the default."""
        embedding_retriever = InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore())
        query_embedder = mock_query_embedder
        retriever = MultiQueryEmbeddingRetriever(
            retriever=embedding_retriever, query_embedder=query_embedder, max_workers=2
        )

        assert retriever.retriever == embedding_retriever
        assert retriever.query_embedder == mock_query_embedder
        assert retriever.max_workers == 2

    def test_run_with_empty_queries(self, mock_query_embedder):
        """Running with no queries returns an empty document list."""
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=mock_query_embedder,
        )

        result = multi_retriever.run(queries=[])

        assert "documents" in result
        assert result["documents"] == []

    def test_run_with_empty_results(self, mock_query_embedder):
        """A query against an empty document store yields an empty document list."""
        # The document store is empty, so the (stubbed) query embedding cannot match anything.
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=mock_query_embedder,
        )
        result = multi_retriever.run(queries=["query"])
        assert "documents" in result
        assert result["documents"] == []

    def test_to_dict(self):
        """Serialization includes the component type, `max_workers` and both nested components."""
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=InMemoryEmbeddingRetriever(document_store=InMemoryDocumentStore()),
            query_embedder=SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2"),
            max_workers=2,
        )

        result = multi_retriever.to_dict()

        assert (
            result["type"]
            == "haystack.components.retrievers.multi_query_embedding_retriever.MultiQueryEmbeddingRetriever"
        )
        assert "init_parameters" in result
        assert result["init_parameters"]["max_workers"] == 2
        assert "retriever" in result["init_parameters"]
        assert "query_embedder" in result["init_parameters"]

    def test_from_dict(self):
        """Deserialization restores `max_workers` and rebuilds the nested retriever and embedder."""
        data = {
            "type": "haystack.components.retrievers.multi_query_embedding_retriever.MultiQueryEmbeddingRetriever",  # noqa: E501
            "init_parameters": {
                "retriever": {
                    "type": "haystack.components.retrievers.in_memory.embedding_retriever.InMemoryEmbeddingRetriever",
                    "init_parameters": {
                        "document_store": {
                            "type": "haystack.document_stores.in_memory.document_store.InMemoryDocumentStore",
                            "init_parameters": {
                                "bm25_tokenization_regex": "(?u)\\b\\w\\w+\\b",
                                "bm25_algorithm": "BM25L",
                                "bm25_parameters": {},
                                "embedding_similarity_function": "dot_product",
                                "index": "4bb5369d-779f-487b-9c16-3c40f503438b",
                                # TODO: investigate why adding "return_embedding": True here makes this test fail.
                            },
                        },
                        "filters": None,
                        "top_k": 10,
                        "scale_score": False,
                        "return_embedding": False,
                        "filter_policy": "replace",
                    },
                },
                "query_embedder": {
                    "type": "haystack.components.embedders.sentence_transformers_text_embedder.SentenceTransformersTextEmbedder",  # noqa: E501
                    "init_parameters": {
                        "model": "sentence-transformers/all-MiniLM-L6-v2",
                        "token": {"type": "env_var", "env_vars": ["HF_API_TOKEN", "HF_TOKEN"], "strict": False},
                        "prefix": "",
                        "suffix": "",
                        "batch_size": 32,
                        "progress_bar": True,
                        "normalize_embeddings": False,
                        "trust_remote_code": False,
                        "local_files_only": False,
                        "truncate_dim": None,
                        "model_kwargs": None,
                        "tokenizer_kwargs": None,
                        "config_kwargs": None,
                        "precision": "float32",
                        "encode_kwargs": None,
                        "backend": "torch",
                    },
                },
                "max_workers": 2,
            },
        }

        result = MultiQueryEmbeddingRetriever.from_dict(data)

        assert isinstance(result, MultiQueryEmbeddingRetriever)
        assert result.max_workers == 2
        # Nested component dicts must be turned back into component instances, not left as raw dicts.
        assert isinstance(result.retriever, InMemoryEmbeddingRetriever)
        assert isinstance(result.query_embedder, SentenceTransformersTextEmbedder)

    def test_deduplication_with_overlapping_results(self, mock_query_embedder):
        """Documents returned for several queries are merged without duplicated content."""
        doc1 = Document(content="Solar energy is renewable", id="doc1")
        doc1.score = 0.9
        doc2 = Document(content="Wind energy is clean", id="doc2")
        doc2.score = 0.8
        # Same content as doc1, but a different id and score.
        doc3 = Document(content="Solar energy is renewable", id="doc3")
        doc3.score = 0.7

        # Mocked retriever: the first call returns [doc1, doc2], every later call returns [doc3, doc2].
        mock_retriever = MagicMock()
        call_count = 0

        def mock_retriever_run(query_embedding, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                return {"documents": [doc1, doc2]}
            return {"documents": [doc3, doc2]}

        mock_retriever.run = mock_retriever_run
        # max_workers=1 keeps the order of retriever calls deterministic for the call counter above.
        multi_retriever = MultiQueryEmbeddingRetriever(
            retriever=mock_retriever, query_embedder=mock_query_embedder, max_workers=1
        )
        result = multi_retriever.run(queries=["query1", "query2"])

        assert "documents" in result
        assert len(result["documents"]) == 2  # Only 2 unique documents (doc1/doc3 and doc2)

        contents = [doc.content for doc in result["documents"]]
        assert contents.count("Solar energy is renewable") == 1
        assert contents.count("Wind energy is clean") == 1

    @pytest.mark.integration
    def test_run_with_filters(self, document_store_with_embeddings):
        """Retriever kwargs (here: metadata filters) are forwarded to the wrapped retriever."""
        in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
        query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        multi_retriever = MultiQueryEmbeddingRetriever(retriever=in_memory_retriever, query_embedder=query_embedder)
        multi_retriever.warm_up()
        kwargs = {"filters": {"field": "category", "operator": "==", "value": "solar"}}
        result = multi_retriever.run(["energy"], kwargs)
        assert "documents" in result
        # Guard against a vacuous pass: `all()` over an empty list is True.
        assert len(result["documents"]) > 0
        assert all(doc.meta.get("category") == "solar" for doc in result["documents"])

    @pytest.mark.skipif(
        not os.environ.get("OPENAI_API_KEY", None),
        reason="Export an env var called OPENAI_API_KEY containing the OpenAI API key to run this test.",
    )
    @pytest.mark.integration
    def test_pipeline_integration(self, document_store_with_embeddings):
        """End-to-end: QueryExpander output feeds the multi-query retriever inside a Pipeline."""
        expander = QueryExpander(
            chat_generator=OpenAIChatGenerator(model="gpt-4.1-mini"), n_expansions=3, include_original_query=True
        )
        in_memory_retriever = InMemoryEmbeddingRetriever(document_store=document_store_with_embeddings)
        query_embedder = SentenceTransformersTextEmbedder(model="sentence-transformers/all-MiniLM-L6-v2")
        multiquery_retriever = MultiQueryEmbeddingRetriever(
            retriever=in_memory_retriever, query_embedder=query_embedder, max_workers=3
        )

        pipeline = Pipeline()
        pipeline.add_component("query_expander", expander)
        pipeline.add_component("multiquery_retriever", multiquery_retriever)
        pipeline.connect("query_expander.queries", "multiquery_retriever.queries")

        data = {
            "query_expander": {"query": "green energy sources"},
            "multiquery_retriever": {"retriever_kwargs": {"top_k": 3}},
        }
        results = pipeline.run(data=data, include_outputs_from={"query_expander", "multiquery_retriever"})

        assert "multiquery_retriever" in results
        assert "documents" in results["multiquery_retriever"]
        assert len(results["multiquery_retriever"]["documents"]) > 0
        assert "query_expander" in results
        assert "queries" in results["query_expander"]
        # 3 expansions plus the original query (include_original_query=True).
        assert len(results["query_expander"]["queries"]) == 4

        # Documents are sorted by score, highest first.
        scores = [doc.score for doc in results["multiquery_retriever"]["documents"] if doc.score is not None]
        assert scores == sorted(scores, reverse=True)

        # No duplicated content in the merged result.
        contents = [doc.content for doc in results["multiquery_retriever"]["documents"]]
        assert len(contents) == len(set(contents))
